package org.hepeng.workx.util.crypto;

import lombok.Data;
import org.apache.commons.codec.binary.Base64;
import org.hepeng.workx.exception.ApplicationRuntimeException;
import org.springframework.util.Assert;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.security.KeyFactory;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * RSA encryption and decryption helper.
 *
 * <p>Keys are exchanged as Base64 strings: the public key in X.509 encoding and the private key in
 * PKCS#8 encoding, exactly as produced by {@link #generateKeyPair()}. Because a single RSA
 * operation can only process a bounded number of bytes, {@link #encrypt(String, byte[])} and
 * {@link #decrypt(String, byte[])} split their input into chunks of at most
 * {@code maxEncryptByteBlock} / {@code maxDecryptByteBlock} bytes and concatenate the per-chunk
 * results.
 *
 * <p>Instances are immutable after construction.
 *
 * @author he peng
 */
public class RSACrypt {

    private static final String ALGORITHM_KEY = "RSA";

    /**
     * Default key size in bits. NOTE(review): 1024-bit RSA is below current recommendations
     * (>= 2048); kept as the default for backward compatibility.
     */
    private static final int DEFAULT_ENCRYPT_FACTOR_SIZE = 1024;
    private static final String DEFAULT_TRANS_FORMAT = "RSA/ECB/PKCS1Padding";
    /** 117 = 128-byte modulus of a 1024-bit key minus 11 bytes of PKCS#1 v1.5 padding. */
    private static final int DEFAULT_MAX_ENCRYPT_BYTE_BLOCK = 117;
    /** 128 = modulus length in bytes of a 1024-bit key, i.e. the size of one ciphertext block. */
    private static final int DEFAULT_MAX_DECRYPT_BYTE_BLOCK = 128;

    /** Key size in bits passed to {@link KeyPairGenerator#initialize(int)}. */
    private final int encryptFactorSize;
    /** Transformation passed to {@link Cipher#getInstance(String)}. */
    private final String transFormat;
    /** Maximum number of plaintext bytes fed to one RSA encryption. */
    private final int maxEncryptByteBlock;
    /** Number of ciphertext bytes fed to one RSA decryption. */
    private final int maxDecryptByteBlock;

    private RSACrypt() {
        this(DEFAULT_ENCRYPT_FACTOR_SIZE, DEFAULT_TRANS_FORMAT,
                DEFAULT_MAX_ENCRYPT_BYTE_BLOCK, DEFAULT_MAX_DECRYPT_BYTE_BLOCK);
    }

    private RSACrypt(int encryptFactorSize , String transFormat , int maxEncryptByteBlock , int maxDecryptByteBlock) {
        // A non-positive block size would make the chunking loop in safeCipher never advance.
        Assert.isTrue(maxEncryptByteBlock > 0 , "maxEncryptByteBlock must be > 0");
        Assert.isTrue(maxDecryptByteBlock > 0 , "maxDecryptByteBlock must be > 0");
        this.encryptFactorSize = encryptFactorSize;
        this.transFormat = transFormat;
        this.maxEncryptByteBlock = maxEncryptByteBlock;
        this.maxDecryptByteBlock = maxDecryptByteBlock;
    }

    /**
     * Creates a new {@link RSACrypt} using the default parameters: 1024-bit keys,
     * {@code RSA/ECB/PKCS1Padding}, 117-byte encryption blocks and 128-byte decryption blocks.
     *
     * @return a new {@link RSACrypt} instance
     */
    public static RSACrypt newRSACrypt() {
        return new RSACrypt();
    }

    /**
     * Creates a new {@link RSACrypt} with custom parameters.
     *
     * <p>The block sizes must match the key size and padding. For example, with PKCS#1 v1.5 padding
     * and a key of {@code k} bytes, use {@code k - 11} for encryption and {@code k} for decryption.
     *
     * @param encryptFactorSize     key size in bits used when generating key pairs
     * @param transFormat           {@link Cipher} transformation, e.g. {@code RSA/ECB/PKCS1Padding}
     * @param maxEncryptByteBlock   maximum plaintext bytes per RSA encryption; must be &gt; 0
     * @param maxDecryptByteBlock   ciphertext bytes per RSA decryption; must be &gt; 0
     * @return a new {@link RSACrypt} instance
     * @throws IllegalArgumentException if a block size is not positive
     */
    public static RSACrypt newRSACrypt(int encryptFactorSize , String transFormat , int maxEncryptByteBlock , int maxDecryptByteBlock) {
        return new RSACrypt(encryptFactorSize , transFormat , maxEncryptByteBlock , maxDecryptByteBlock);
    }

    /**
     * Generates a new RSA key pair of {@code encryptFactorSize} bits.
     *
     * @return a {@link KeyPair} holding the Base64 encoded X.509 public key and PKCS#8 private key
     * @throws ApplicationRuntimeException if the RSA algorithm is unavailable
     */
    public KeyPair generateKeyPair() {
        KeyPairGenerator keyPairGenerator;
        try {
            keyPairGenerator = KeyPairGenerator.getInstance(ALGORITHM_KEY);
        } catch (NoSuchAlgorithmException e) {
            throw new ApplicationRuntimeException(e);
        }
        keyPairGenerator.initialize(this.encryptFactorSize);
        // Fully qualified: the nested KeyPair class below shadows java.security.KeyPair.
        java.security.KeyPair keyPair = keyPairGenerator.generateKeyPair();
        PrivateKey privateKey = keyPair.getPrivate();
        PublicKey publicKey = keyPair.getPublic();
        KeyPair kp = new KeyPair();
        kp.setPrivateKey(Base64.encodeBase64String(privateKey.getEncoded()));
        kp.setPublicKey(Base64.encodeBase64String(publicKey.getEncoded()));
        return kp;
    }

    /** Decodes a Base64 X.509-encoded RSA public key. */
    private static PublicKey decodePublicKey(String publicKey) throws Exception {
        Assert.hasLength(publicKey , "publicKey Length == 0");
        X509EncodedKeySpec encodedKeySpec =
                new X509EncodedKeySpec(Base64.decodeBase64(publicKey));
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_KEY);
        return keyFactory.generatePublic(encodedKeySpec);
    }

    /** Decodes a Base64 PKCS#8-encoded RSA private key. */
    private static PrivateKey decodePrivateKey(String privateKey) throws Exception {
        Assert.hasLength(privateKey , "privateKey Length == 0");
        PKCS8EncodedKeySpec encodedKeySpec =
                new PKCS8EncodedKeySpec(Base64.decodeBase64(privateKey));
        KeyFactory keyFactory = KeyFactory.getInstance(ALGORITHM_KEY);
        return keyFactory.generatePrivate(encodedKeySpec);
    }

    /**
     * Encrypts data with a public key.
     *
     * @param publicKey Base64 encoded X.509 public key
     * @param content   data to encrypt; split into chunks of {@code maxEncryptByteBlock} bytes
     * @return the concatenated ciphertext blocks
     * @throws ApplicationRuntimeException wrapping any key-decoding or cipher failure
     */
    public byte[] encrypt(String publicKey , byte[] content) {
        try {
            PublicKey pk = decodePublicKey(publicKey);
            Cipher cipher = Cipher.getInstance(this.transFormat);
            cipher.init(Cipher.ENCRYPT_MODE , pk);
            return safeCipher(content , this.maxEncryptByteBlock , cipher);
        } catch (Exception e) {
            throw new ApplicationRuntimeException("RSACrypt Encrypt Error : " + e.getMessage() , e);
        }
    }

    /**
     * Encrypts data with a public key and returns the ciphertext as a Base64 string.
     *
     * @param publicKey  Base64 encoded X.509 public key
     * @param content    data to encrypt
     * @return Base64 encoding of the ciphertext produced by {@link #encrypt(String, byte[])}
     * @throws ApplicationRuntimeException wrapping any key-decoding or cipher failure
     */
    public String encryptBase64(String publicKey , byte[] content) {
        byte[] bytes = encrypt(publicKey, content);
        return Base64.encodeBase64String(bytes);
    }

    /**
     * Decrypts data with a private key.
     *
     * @param privateKey    Base64 encoded PKCS#8 private key
     * @param content       ciphertext; split into chunks of {@code maxDecryptByteBlock} bytes
     * @return the concatenated plaintext
     * @throws ApplicationRuntimeException wrapping any key-decoding or cipher failure
     */
    public byte[] decrypt(String privateKey , byte[] content) {
        try {
            PrivateKey pk = decodePrivateKey(privateKey);
            Cipher cipher = Cipher.getInstance(this.transFormat);
            cipher.init(Cipher.DECRYPT_MODE , pk);
            return safeCipher(content , this.maxDecryptByteBlock , cipher);
        } catch (Exception e) {
            throw new ApplicationRuntimeException("RSACrypt Decrypt Error : " + e.getMessage() , e);
        }
    }

    /**
     * Decrypts a Base64 encoded ciphertext with a private key.
     *
     * @param privateKey    Base64 encoded PKCS#8 private key
     * @param base64String  Base64 encoding of the ciphertext
     * @return the decrypted bytes
     * @throws ApplicationRuntimeException wrapping any key-decoding or cipher failure
     */
    public byte[] decryptBase64(String privateKey , String base64String) {
        byte[] bytes = Base64.decodeBase64(base64String);
        return decrypt(privateKey , bytes);
    }

    /**
     * Runs {@code cipher} over {@code content} in chunks of at most {@code size} bytes and
     * concatenates the per-chunk outputs.
     *
     * <p>The cipher is reused across chunks; {@link Cipher#doFinal} resets it to its initialized
     * state after each call.
     *
     * @param content input bytes, must not be null
     * @param size    maximum bytes per RSA operation, must be &gt; 0
     * @param cipher  an initialized cipher
     * @return the concatenated output of every chunk
     */
    private static byte[] safeCipher(byte[] content , int size , Cipher cipher) throws BadPaddingException, IllegalBlockSizeException {
        Assert.notNull(content , "content must not be null");
        if (content.length <= size) {
            return cipher.doFinal(content);
        }
        // Growable buffer: the total output size depends on the number of chunks, so a fixed
        // capacity would overflow for large inputs.
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (int offset = 0; offset < content.length; offset += size) {
            // doFinal takes a LENGTH, not an end index: the last chunk may be shorter than size.
            int len = Math.min(size , content.length - offset);
            byte[] chunk = cipher.doFinal(content , offset , len);
            out.write(chunk , 0 , chunk.length);
        }
        return out.toByteArray();
    }

    /**
     * Base64 encoded RSA key pair as produced by {@link RSACrypt#generateKeyPair()}.
     *
     * <p>NOTE(review): Lombok's {@code @Data} generates a {@code toString()} that includes the
     * private key; avoid logging instances of this class.
     */
    @Data
    public static class KeyPair {
        /** Base64 encoded X.509 public key. */
        private String publicKey;
        /** Base64 encoded PKCS#8 private key. */
        private String privateKey;
    }
}
